#!/usr/bin/env python

# Copyright (c) ONNX Project Contributors

# SPDX-License-Identifier: Apache-2.0
from __future__ import annotations

import os
from collections import defaultdict
from collections.abc import Sequence
from typing import Any, NamedTuple

import numpy as np

from onnx import defs, helper
from onnx.backend.sample.ops import collect_sample_implementations
from onnx.backend.test.case import collect_snippets
from onnx.defs import ONNX_ML_DOMAIN, OpSchema

# Example code per operator; indexed by operator name below, each entry
# yielding (summary, code) pairs.
SNIPPETS = collect_snippets()
# Sample implementations, looked up below by lower-cased operator name.
SAMPLE_IMPLEMENTATIONS = collect_sample_implementations()
# ML documents are generated unless the ONNX_ML environment variable is
# exactly "0" (an unset variable therefore leaves them enabled).
ONNX_ML = os.getenv("ONNX_ML") != "0"


def display_number(v: int) -> str:
    """Return ``v`` as text, or the HTML infinity entity when the schema API reports it as infinite."""
    return "&#8734;" if defs.OpSchema.is_infinite(v) else str(v)


def should_render_domain(domain: str, output: str) -> bool:
    """Decide whether schemas of ``domain`` belong in the file named ``output``.

    The ML domain is rendered only into outputs whose name contains "-ml";
    every other domain is rendered only into outputs whose name does not.
    """
    output_is_ml = "-ml" in output
    domain_is_ml = domain == ONNX_ML_DOMAIN
    return output_is_ml == domain_is_ml


def format_name_with_domain(domain: str, schema_name: str) -> str:
    """Return ``domain.schema_name``, or just ``schema_name`` when ``domain`` is empty."""
    return f"{domain}.{schema_name}" if domain else schema_name


def format_function_versions(function_versions: Sequence[int]) -> str:
    """Return the given version numbers as a comma-and-space separated string."""
    return ", ".join(map(str, function_versions))


def format_versions(versions: Sequence[OpSchema], changelog: str) -> str:
    """Return comma-separated changelog links for ``versions``, in reverse of the given order."""
    links = [
        display_version_link(
            format_name_with_domain(schema.domain, schema.name),
            schema.since_version,
            changelog,
        )
        for schema in reversed(versions)
    ]
    return ", ".join(links)


def display_attr_type(v: OpSchema.AttrType) -> str:
    """Return a lowercase display label for an attribute type.

    The label is the text of ``str(v)`` after its last dot, lower-cased.
    Labels ending in "s" are prefixed with "list of ".
    """
    assert isinstance(v, OpSchema.AttrType)
    label = str(v).rpartition(".")[2].lower()
    return "list of " + label if label[-1] == "s" else label


def display_domain(domain: str) -> str:
    """Return the phrase naming the operator set of ``domain`` (empty means the default set)."""
    return f"the '{domain}' operator set" if domain else "the default ONNX operator set"


def display_domain_short(domain: str) -> str:
    """Return ``domain`` itself, or the label "ai.onnx (default)" when it is empty."""
    return domain if domain else "ai.onnx (default)"


def display_version_link(name: str, version: int, changelog: str) -> str:
    """Return an HTML link, labelled with ``version``, to the ``name-version`` anchor in ``changelog``."""
    return f'<a href="{changelog}#{name}-{version}">{version}</a>'


def generate_formal_parameter_tags(formal_parameter: OpSchema.FormalParameter) -> str:
    """Return the parenthesised tag list shown after an input/output name.

    Tags describe the parameter option (optional, variadic, and heterogeneous
    for non-homogeneous variadics) followed by its differentiability when one
    is specified. Returns an empty string when no tag applies; otherwise the
    result starts with a space, e.g. " (optional, differentiable)".
    """
    labels: list[str] = []

    option = formal_parameter.option
    if option == OpSchema.FormalParameterOption.Optional:
        labels.append("optional")
    elif option == OpSchema.FormalParameterOption.Variadic:
        labels.append("variadic")
        if not formal_parameter.is_homogeneous:
            labels.append("heterogeneous")

    category = formal_parameter.differentiation_category
    if category == OpSchema.DifferentiationCategory.Differentiable:
        labels.append("differentiable")
    elif category == OpSchema.DifferentiationCategory.NonDifferentiable:
        labels.append("non-differentiable")

    if not labels:
        return ""
    return f" ({', '.join(labels)})"


def _format_attr_default(value: Any) -> str:
    """Render a single attribute default value (or list element) as text."""
    if isinstance(value, float):
        formatted = str(np.round(value, 5))
        # Use the rounded default formatting unless it is too long, in which
        # case fall back to parenthesised scientific notation.
        if len(formatted) > 10:  # noqa: PLR2004
            formatted = f"({value:e})"
        return formatted
    if isinstance(value, (bytes, bytearray)):
        return value.decode("utf-8")
    return str(value)


def _format_formal_parameters(
    title: str,
    min_count: int,
    max_count: int,
    parameters: Sequence[OpSchema.FormalParameter],
) -> str:
    """Render the "Inputs" or "Outputs" section: heading, arity range and a definition list."""
    chunks = [f"\n#### {title}"]
    # The arity range is only shown when the count is not fixed.
    if min_count != max_count:
        chunks.append(f" ({display_number(min_count)} - {display_number(max_count)})")
    chunks.append("\n\n")
    if parameters:
        chunks.append("<dl>\n")
        for parameter in parameters:
            tags = generate_formal_parameter_tags(parameter)
            chunks.append(
                f"<dt><tt>{parameter.name}</tt>{tags} : {parameter.type_str}</dt>\n"
            )
            chunks.append(f"<dd>{parameter.description}</dd>\n")
        chunks.append("</dl>\n")
    return "".join(chunks)


def display_schema(
    schema: OpSchema, versions: Sequence[OpSchema], changelog: str
) -> str:
    """Render the documentation body for one operator schema.

    The result contains, in order: the indented doc text (if any), a
    "Version" section, and, unless the schema is deprecated, the
    "Attributes" (if any), "Inputs", "Outputs" and "Type Constraints"
    sections.

    Args:
        schema: The schema to document.
        versions: Schemas of the same operator; when more than one is given,
            every element except the last is linked under "Other versions".
        changelog: Changelog file name used as the link target for versions.

    Returns:
        The rendered Markdown/HTML text.
    """
    parts: list[str] = []

    # doc
    if schema.doc:
        parts.append("\n")
        parts.append(
            "\n".join(
                ("  " + line).rstrip() for line in schema.doc.lstrip().splitlines()
            )
        )
        parts.append("\n")

    # since version
    parts.append("\n#### Version\n")
    if schema.support_level == OpSchema.SupportType.EXPERIMENTAL:
        parts.append("\nNo versioning maintained for experimental ops.")
    else:
        status = "deprecated" if schema.deprecated else "available"
        parts.append(
            f"\nThis version of the operator has been {status}"
            f" since version {schema.since_version}"
            f" of {display_domain(schema.domain)}.\n"
        )
        if len(versions) > 1:
            # TODO: link to the Changelog.md
            other_versions = ", ".join(
                display_version_link(
                    format_name_with_domain(v.domain, v.name),
                    v.since_version,
                    changelog,
                )
                for v in versions[:-1]
            )
            parts.append(f"\nOther versions of this operator: {other_versions}\n")

    # If this schema is deprecated, don't display any of the following sections
    if schema.deprecated:
        return "".join(parts)

    # attributes
    if schema.attributes:
        parts.append("\n#### Attributes\n\n")
        parts.append("<dl>\n")
        for _, attr in sorted(schema.attributes.items()):
            # option holds either required or default value
            opt = ""
            if attr.required:
                opt = "required"
            elif attr.default_value.name:
                raw_default = helper.get_attribute_value(attr.default_value)
                # List defaults are shown as the repr of the list of
                # individually formatted elements.
                shown_default: Any
                if isinstance(raw_default, list):
                    shown_default = [_format_attr_default(val) for val in raw_default]
                else:
                    shown_default = _format_attr_default(raw_default)
                opt = f"default is {shown_default}{attr.default_value.doc_string}"

            opt_suffix = f" ({opt})" if opt else ""
            parts.append(
                f"<dt><tt>{attr.name}</tt> : {display_attr_type(attr.type)}{opt_suffix}</dt>\n"
            )
            parts.append(f"<dd>{attr.description}</dd>\n")
        parts.append("</dl>\n")

    # inputs
    parts.append(
        _format_formal_parameters(
            "Inputs", schema.min_input, schema.max_input, schema.inputs
        )
    )

    # outputs
    parts.append(
        _format_formal_parameters(
            "Outputs", schema.min_output, schema.max_output, schema.outputs
        )
    )

    # type constraints
    parts.append("\n#### Type Constraints")
    parts.append("\n\n")
    if schema.type_constraints:
        parts.append("<dl>\n")
        for type_constraint in schema.type_constraints:
            # Joining handles an empty list of allowed types; the previous
            # manual concatenation left the variable unbound (or stale from
            # the preceding constraint) in that case.
            allowed_types = ", ".join(type_constraint.allowed_type_strs)
            parts.append(
                f"<dt><tt>{type_constraint.type_param_str}</tt> : {allowed_types}</dt>\n"
            )
            parts.append(f"<dd>{type_constraint.description}</dd>\n")
        parts.append("</dl>\n")

    # Function Body
    # TODO: this should be refactored to show the function body graph's picture (DAG).
    # if schema.has_function or schema.has_context_dependent_function:  # type: ignore
    #    s += '\n#### Function\n'
    #    s += '\nThe Function can be represented as a function.\n'

    return "".join(parts)


def support_level_str(level: OpSchema.SupportType) -> str:
    """Return the "experimental" subscript marker (with trailing space) for experimental ops, else ""."""
    if level == OpSchema.SupportType.EXPERIMENTAL:
        return "<sub>experimental</sub> "
    return ""


class Args(NamedTuple):
    """File names for one documentation run; main() joins both onto the docs directory."""

    # Operator-schema Markdown file to write. Whether its name contains "-ml"
    # selects which domains are rendered (see should_render_domain).
    output: str
    # Changelog Markdown file to write; also used as the target of version links.
    changelog: str


def _doc_preamble(title: str) -> str:
    """Return the license comment, ``title`` heading and auto-generation notice shared by both files."""
    return (
        "<!--- SPDX-License-Identifier: Apache-2.0 -->\n"
        f"## {title}\n"
        "*This file is automatically generated from the\n"
        "            [def files](/onnx/defs) via [this script](/onnx/defs/gen_doc.py).\n"
        "            Do not modify directly and instead edit operator definitions.*\n"
        "\n"
        "For an operator input/output's differentiability, it can be differentiable,\n"
        "            non-differentiable, or undefined. If a variable's differentiability\n"
        "            is not specified, that variable has undefined differentiability.\n"
    )


def _render_changelog_domain(
    domain: str, versionmap: dict[int, list[OpSchema]], changelog: str
) -> str:
    """Render one domain of the changelog: schemas grouped by version, sorted by name."""
    chunks = [f"# {display_domain_short(domain)}\n"]
    for version, unsorted_schemas in sorted(versionmap.items()):
        chunks.append(f"## Version {version} of {display_domain(domain)}\n")
        for schema in sorted(unsorted_schemas, key=lambda sch: sch.name):
            name_with_ver = (
                f"{format_name_with_domain(domain, schema.name)}-{schema.since_version}"
            )
            deprecated = " (deprecated)" if schema.deprecated else ""
            # NOTE: the trailing </a> has no matching opening tag; it is kept
            # so that the generated text stays unchanged.
            chunks.append(
                f'### <a name="{name_with_ver}"></a>**{name_with_ver}**{deprecated}</a>\n'
            )
            chunks.append(display_schema(schema, [schema], changelog))
            chunks.append("\n")
    return "".join(chunks)


def _write_changelog(docs_dir: str, args: Args) -> None:
    """Write the changelog file listing every schema version of the rendered domains."""
    # domain -> version -> [schema]
    dv_index: dict[str, dict[int, list[OpSchema]]] = defaultdict(
        lambda: defaultdict(list)
    )
    for schema in defs.get_all_schemas_with_history():
        dv_index[schema.domain][schema.since_version].append(schema)

    with open(
        os.path.join(docs_dir, args.changelog), "w", newline="", encoding="utf-8"
    ) as fout:
        fout.write(_doc_preamble("Operator Changelog"))
        fout.write("\n")
        for domain, versionmap in sorted(dv_index.items()):
            if not should_render_domain(domain, args.output):
                continue
            fout.write(_render_changelog_domain(domain, versionmap, args.changelog))


def _collect_operator_schemas(
    output: str,
) -> list[tuple[str, list[tuple[int, list[tuple[str, OpSchema, list[OpSchema]]]]]]]:
    """Group the schemas to render for ``output``.

    Returns:
        [(domain, [(support_level, [(schema name, current schema, all versions schemas)])])]
        sorted at every level. An operator name is kept only the first time
        it is encountered.
    """
    # domain -> support level -> name -> [schema]
    index: dict[str, dict[int, dict[str, list[OpSchema]]]] = defaultdict(
        lambda: defaultdict(lambda: defaultdict(list))
    )
    for schema in defs.get_all_schemas_with_history():
        index[schema.domain][int(schema.support_level)][schema.name].append(schema)

    operator_schemas: list[
        tuple[str, list[tuple[int, list[tuple[str, OpSchema, list[OpSchema]]]]]]
    ] = []
    existing_ops: set[str] = set()
    for domain, supportmap in sorted(index.items()):
        if not should_render_domain(domain, output):
            continue

        processed_supportmap = []
        for support, namemap in sorted(supportmap.items()):
            processed_namemap = []
            for name, unsorted_versions in sorted(namemap.items()):
                versions = sorted(unsorted_versions, key=lambda sch: sch.since_version)
                # The schema with the highest since_version is the current one.
                current = versions[-1]
                if current.name in existing_ops:
                    continue
                existing_ops.add(current.name)
                processed_namemap.append((name, current, versions))
            processed_supportmap.append((support, processed_namemap))
        operator_schemas.append((domain, processed_supportmap))
    return operator_schemas


def _render_toc(
    domain: str,
    supportmap: list[tuple[int, list[tuple[str, OpSchema, list[OpSchema]]]]],
    changelog: str,
) -> str:
    """Render one domain's table of contents: operator rows first, then function rows."""
    rows = [
        f"### {display_domain_short(domain)}\n",
        "|**Operator**|**Since version**||\n",
        "|-|-|-|\n",
    ]
    function_rows: list[str] = []
    for _, namemap in supportmap:
        for name, schema, versions in namemap:
            prefix = support_level_str(schema.support_level)
            qualified = format_name_with_domain(domain, name)
            since = format_versions(versions, changelog)
            if schema.has_function or schema.has_context_dependent_function:  # type: ignore
                function_versions = format_function_versions(
                    schema.all_function_opset_versions  # type: ignore
                )
                # Function rows are emitted after all plain operator rows.
                function_rows.append(
                    f'|{prefix}<a href="#{qualified}">{qualified}</a>|{since}|{function_versions}|\n'
                )
                continue
            deprecated = " (deprecated)" if schema.deprecated else ""
            rows.append(
                f'|{prefix}<a href="#{qualified}">{qualified}</a>{deprecated}|{since}|\n'
            )
    if function_rows:
        rows.append("|**Function**|**Since version**|**Function version**|\n")
        rows.extend(function_rows)
    rows.append("\n")
    return "".join(rows)


def _render_operator(
    domain: str,
    op_type: str,
    schema: OpSchema,
    versions: list[OpSchema],
    changelog: str,
) -> str:
    """Render one operator section: heading, schema body, examples and sample implementation."""
    prefix = support_level_str(schema.support_level)
    qualified = format_name_with_domain(domain, op_type)
    qualified_lower = format_name_with_domain(domain, op_type.lower())
    deprecated = " (deprecated)" if schema.deprecated else ""
    chunks = [
        f'### {prefix}<a name="{qualified}"></a><a name="{qualified_lower}">**{qualified}**{deprecated}</a>\n',
        display_schema(schema, versions, changelog),
        "\n\n",
    ]

    if op_type in SNIPPETS:
        chunks.append("#### Examples\n\n")
        for summary, code in sorted(SNIPPETS[op_type]):
            chunks.append(
                "<details>\n"
                f"<summary>{summary}</summary>\n\n"
                f"```python\n{code}\n```\n\n"
                "</details>\n"
                "\n\n"
            )

    sample_key = op_type.lower()
    if sample_key in SAMPLE_IMPLEMENTATIONS:
        chunks.append(
            "#### Sample Implementation\n\n"
            "<details>\n"
            f"<summary>{op_type}</summary>\n\n"
            f"```python\n{SAMPLE_IMPLEMENTATIONS[sample_key]}\n```\n\n"
            "</details>\n"
            "\n\n"
        )
    return "".join(chunks)


def _write_operators(docs_dir: str, args: Args) -> None:
    """Write the operator-schema file: table of contents followed by one section per operator."""
    operator_schemas = _collect_operator_schemas(args.output)

    with open(
        os.path.join(docs_dir, args.output), "w", newline="", encoding="utf-8"
    ) as fout:
        fout.write(_doc_preamble("Operator Schemas"))
        fout.write("\n")

        # Table of contents
        for domain, supportmap in operator_schemas:
            fout.write(_render_toc(domain, supportmap, args.changelog))
        fout.write("\n")

        for domain, supportmap in operator_schemas:
            fout.write(f"## {display_domain_short(domain)}\n")
            for _, namemap in supportmap:
                for op_type, schema, versions in namemap:
                    fout.write(
                        _render_operator(
                            domain, op_type, schema, versions, args.changelog
                        )
                    )


def main(args: Args) -> None:
    """Generate the changelog and operator-schema Markdown files named by ``args``.

    Both files are written into the ``docs`` directory located three levels
    above this file's resolved path; the changelog is written first.
    """
    base_dir = os.path.dirname(
        os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
    )
    docs_dir = os.path.join(base_dir, "docs")

    _write_changelog(docs_dir, args)
    _write_operators(docs_dir, args)


if __name__ == "__main__":
    # The ML documents are generated first, and only when ONNX_ML is enabled;
    # the default documents are always generated.
    if ONNX_ML:
        main(Args(output="Operators-ml.md", changelog="Changelog-ml.md"))
    main(Args(output="Operators.md", changelog="Changelog.md"))
